local super = require "Graph"

--- CategoryGraph: a Graph subclass that plots a dataset against a category
-- axis and a numeric axis. It is published as a global and is also the value
-- returned at the end of this module.
CategoryGraph = super:new()

-- Values accepted by the 'groupStyle' property (consumed by isStacked and
-- getGroupItemSpacing). Installed in declaration order.
for _, style in ipairs{
    { 'stackedGroupStyle', "stacked" },
    { 'tightGroupStyle', "tight" },
    { 'wideGroupStyle', "wide" },
} do
    CategoryGraph[style[1]] = style[2]
end

-- Hook implementations shared by both handles below. Each hook is invoked
-- with the graph as `self` (they call self:getContentRect() and
-- self:getNumericAxis()).

--- Token getter: captures the numeric-axis value under the pointer.
-- Only the coordinate along the numeric axis is used: `ty` is returned in
-- vertical orientation, `tx` otherwise (the other one is nil).
local function numericAxisTokenAt(self, x, y)
    local rect = self:getContentRect()
    local axis = self:getNumericAxis()
    local tx, ty
    if self:getOrientation() == Graph.verticalOrientation then
        ty = axis:scaled(rect, y)
    else
        tx = axis:scaled(rect, x)
    end
    return tx, ty
end

--- Token setter: pans the numeric axis so that the captured value ends up
-- under the pointer at (x, y). The range is shifted by the pointer offset
-- across the whole content rect, so its pixel span is unchanged.
local function panNumericAxisToPointer(self, tx, ty, x, y)
    local rect = self:getContentRect()
    local axis = self:getNumericAxis()
    local minValue, maxValue
    if tx and self:getOrientation() == Graph.horizontalOrientation then
        local draggingPosition = axis:scale(rect, tx)
        minValue = axis:scaled(rect, rect:minx() + draggingPosition - x)
        maxValue = axis:scaled(rect, rect:maxx() + draggingPosition - x)
    elseif ty and self:getOrientation() == Graph.verticalOrientation then
        local draggingPosition = axis:scale(rect, ty)
        minValue = axis:scaled(rect, rect:miny() + draggingPosition - y)
        maxValue = axis:scaled(rect, rect:maxy() + draggingPosition - y)
    end
    if minValue and maxValue then
        axis:setRange(minValue, maxValue)
    end
end

--- Track thickness shared by both handles: 8, plus a flag passed through to
-- the handle unchanged (its meaning is defined by the handle classes).
local function axisTrackThickness(self)
    return 8, true
end

local handles = {
    -- Pans the numeric axis. Its track is a line through the axis origin,
    -- clamped into the content rect, spanning the category direction.
    PanningHandle:new{
        actionName = "Adjust Axis",
        token = Hook:new(numericAxisTokenAt, panNumericAxisToPointer),
        track = Hook:new(
            function(self)
                local rect = self:getContentRect()
                local axis = self:getNumericAxis()
                local originPosition = axis:scale(rect, axis:origin())
                local x1, y1, x2, y2
                if self:getOrientation() == Graph.verticalOrientation then
                    originPosition = math._mid(rect:miny(), originPosition, rect:maxy())
                    x1, y1 = rect:minx(), originPosition
                    x2, y2 = rect:maxx(), originPosition
                else
                    originPosition = math._mid(rect:minx(), originPosition, rect:maxx())
                    x1, y1 = originPosition, rect:miny()
                    x2, y2 = originPosition, rect:maxy()
                end
                return x1, y1, x2, y2
            end,
            nil),
        trackThickness = Hook:new(axisTrackThickness, nil),
    },
    -- Pans the numeric axis along its edge, and rescales it when one of its
    -- major-tick tokens is dragged. The track is the left edge in vertical
    -- orientation and the bottom edge otherwise.
    AxisHandle:new{
        actionName = "Adjust Axis",
        token = Hook:new(numericAxisTokenAt, panNumericAxisToPointer),
        track = Hook:new(
            function(self)
                local rect = self:getContentRect()
                local x1, y1, x2, y2
                if self:getOrientation() == Graph.verticalOrientation then
                    x1 = rect.left
                    y1 = rect.bottom
                    x2 = rect.left
                    y2 = rect.top
                else
                    x1 = rect.left
                    y1 = rect.bottom
                    x2 = rect.right
                    y2 = rect.bottom
                end
                return x1, y1, x2, y2
            end,
            nil),
        trackThickness = Hook:new(axisTrackThickness, nil),
        tokenPositions = Hook:new(
            -- Getter: the major values produced by axis:distribute (seeded with
            -- the origin, or with the edge value when the origin is outside the
            -- content rect) and their positions, excluding the origin value.
            function(self)
                local rect = self:getContentRect()
                local axis = self:getNumericAxis()
                local originValue = axis:origin()
                local originPosition = axis:scale(rect, originValue)
                local minPosition, maxPosition
                if self:getOrientation() == Graph.verticalOrientation then
                    minPosition, maxPosition = rect:miny(), rect:maxy()
                else
                    minPosition, maxPosition = rect:minx(), rect:maxx()
                end
                local crossingValue = originValue
                if originPosition <= minPosition then
                    crossingValue = axis:scaled(rect, minPosition)
                elseif originPosition >= maxPosition then
                    crossingValue = axis:scaled(rect, maxPosition)
                end
                local majorValues, majorPositions = axis:distribute(rect, crossingValue)
                local tokens, positions = {}, {}
                for index = 1, #majorValues do
                    if majorValues[index] ~= originValue then
                        tokens[#tokens + 1] = majorValues[index]
                        positions[#positions + 1] = majorPositions[index]
                    end
                end
                return tokens, positions
            end,
            -- Setter: rescales the axis about the (clamped) axis crossing so that
            -- the dragged token moves from its old position to `position`.
            -- Dragging across the crossing, and zooming in general, is capped by
            -- 20 * |oldPosition - axisPosition| / extent.
            function(self, token, position)
                local rect = self:getContentRect()
                local axis = self:getNumericAxis()
                local oldPosition = axis:scale(rect, token)
                local originValue = axis:origin()
                local originPosition = axis:scale(rect, originValue)
                local minPosition, maxPosition
                if self:getOrientation() == Graph.verticalOrientation then
                    minPosition, maxPosition = rect:miny(), rect:maxy()
                else
                    minPosition, maxPosition = rect:minx(), rect:maxx()
                end
                local axisPosition = math._mid(minPosition, originPosition, maxPosition)
                -- A token lying exactly on the crossing cannot define a scale
                -- factor. The cap below would be 0, and the range would collapse
                -- to a single value (minValue == maxValue). This happens when the
                -- origin is clamped to an edge and the edge tick is offered as a
                -- token. A zero-extent rect would divide by zero. Ignore both.
                if oldPosition == axisPosition or maxPosition <= minPosition then
                    return
                end
                local ratio
                if (oldPosition < axisPosition and position < axisPosition) or (oldPosition > axisPosition and position > axisPosition) then
                    ratio = (oldPosition - axisPosition) / (position - axisPosition)
                else
                    ratio = math.huge
                end
                ratio = math.min(ratio, 20 * math.abs(oldPosition - axisPosition) / (maxPosition - minPosition))
                local maxValue = axis:scaled(rect, axisPosition + ratio * (maxPosition - axisPosition))
                local minValue = axis:scaled(rect, axisPosition - ratio * (axisPosition - minPosition))
                axis:setRange(minValue, maxValue)
            end),
    },
}

-- Properties registered on every new CategoryGraph, with their initial values.
local defaults = {
    groupStyle = CategoryGraph.stackedGroupStyle,
}

-- Properties registered on every new CategoryGraph without an initial value.
local nilDefaults = {
    'dataset',
}

--- Creates a new category graph.
-- Registers the default properties and creates the orientation hook, which
-- starts vertical and is observed by the graph itself. It also installs a
-- layer-list 'add' observer that hands every added layer the graph's dataset
-- hook and orientation hook, and marks the layer's position as constrained.
-- @return the new graph
function CategoryGraph:new()
    local graph = super.new(self)

    for name, initialValue in pairs(defaults) do
        graph:addProperty(name, initialValue)
    end
    for _, name in ipairs(nilDefaults) do
        graph:addProperty(name)
    end

    graph._orientationHook = PropertyHook:new(Graph.verticalOrientation)
    graph._orientationHook:addObserver(graph)

    -- Kept on the instance so the same function can be identified later.
    graph._categoryAddLayerObserver = function(item)
        item:setDatasetHook(graph:getPropertyHook('dataset'))
        item:setOrientationHook(graph._orientationHook)
        item:setPositionConstrained(true)
    end
    graph:getLayerList():addEventObserver('add', graph._categoryAddLayerObserver)

    return graph
end

--- Unarchiver for the legacy 'padding' property.
-- Documents saved by version 1.1.2 and earlier contain a 'padding' property.
-- Its archived value is intentionally ignored here.
-- @param archived the archived padding value (unused)
function CategoryGraph:unarchivePadding(archived)
    return
end

--- Restores the archived orientation without registering an undo step.
-- Undo recording on the orientation hook is switched off while the value is
-- applied, then switched back on.
-- @param archived archived orientation, decoded with unarchive()
function CategoryGraph:unarchiveOrientation(archived)
    local hook = self._orientationHook
    hook:setUndoable(false)
    local orientation = unarchive(archived)
    hook:setValue(orientation)
    hook:setUndoable(true)
end

--- Archives the graph.
-- Extends the superclass's archived properties with the current orientation.
-- @return the type name produced by the superclass
-- @return the properties table, with `orientation` added
function CategoryGraph:archive()
    local archivedType, archivedProperties = super.archive(self)
    archivedProperties.orientation = self:getOrientation()
    return archivedType, archivedProperties
end

--- Returns a fresh table holding this class's axis handles followed by the
-- handles supplied by the superclass.
function CategoryGraph:getHandles()
    local inheritedHandles = super.getHandles(self)
    return appendtables({}, handles, inheritedHandles)
end

--- Returns the font used for category-axis labels (the typography scheme's
-- category font).
-- @param parent unused here
function CategoryGraph:getAxisCategoryFont(parent)
    local fontRole = TypographyScheme.categoryFont
    return self:getFont(fontRole)
end

--- Returns the superclass's inspectors with a dataset inspector placed first.
-- The dataset inspector is given an 'ordered' hook whose value is always true.
function CategoryGraph:getInspectors()
    local inspectors = super.getInspectors(self)
    local datasetInspector = self:createInspector('Dataset', {'dataset'}, 'Dataset')
    local orderedHook = Hook:new(true)
    datasetInspector:addHook(orderedHook, 'ordered')
    inspectors:insert(datasetInspector, 1)
    return inspectors
end

--- Returns the superclass's font inspectors followed by inspectors for the
-- axis-value, axis-category and text-series fonts, in that order.
function CategoryGraph:getFontInspectors()
    local inspectors = super.getFontInspectors(self)
    local fontEntries = {
        { TypographyScheme.quantityFont, 'Axis Values' },
        { TypographyScheme.categoryFont, 'Axis Categories' },
        { TypographyScheme.labelFont, 'Text Series' },
    }
    for _, entry in ipairs(fontEntries) do
        inspectors:add(self:createFontInspector(entry[1], entry[2]))
    end
    return inspectors
end

--- Returns the superclass's axis-value inspectors plus an 'Orientation'
-- inspector.
-- Changing the orientation through that inspector also swaps the axes, so the
-- numeric axis is horizontal in horizontal orientation and vertical otherwise.
function CategoryGraph:getAxisValueInspectors()
    local inspectors = super.getAxisValueInspectors(self)
    local orientationInspector = Inspector:new{
        title = 'Orientation',
        type = 'Orientation',
    }

    local function readOrientation()
        return self:getOrientation()
    end

    local function applyOrientation(value)
        -- Capture both axes before switching: getCategoryAxis/getNumericAxis
        -- choose an axis based on the current orientation.
        local categoryAxis = self:getCategoryAxis()
        local numericAxis = self:getNumericAxis()
        self:setOrientation(value)
        local horizontalAxis, verticalAxis = categoryAxis, numericAxis
        if value == Graph.horizontalOrientation then
            horizontalAxis, verticalAxis = numericAxis, categoryAxis
        end
        self:setHorizontalAxis(horizontalAxis)
        self:setVerticalAxis(verticalAxis)
    end

    local orientationHook = Hook:new(readOrientation, applyOrientation)
    -- NOTE(review): every call adds another observer to _orientationHook and
    -- nothing in this file removes it; confirm observers are released elsewhere.
    self._orientationHook:addObserver(orientationHook)
    orientationInspector:addHook(orientationHook)
    inspectors:add(orientationInspector)
    return inspectors
end

--- Builds the 'Groups' inspector, bound to the 'groupStyle' property.
function CategoryGraph:getGroupStyleInspector()
    local groupsInspector = Inspector:new{
        title = 'Groups',
        type = 'GroupStyle',
    }
    local groupStyleHook = Hook:new(
        function() return self:getProperty('groupStyle') end,
        function(style) self:setProperty('groupStyle', style) end)
    groupsInspector:addHook(groupStyleHook)
    return groupsInspector
end

--- Draws the miniature sample chart used by the scheme previews.
-- The sample has a title, an optional axis title, value labels '0', '50' and
-- '100', three 'Label' category labels, stacked bars and a baseline. It
-- appears to use y-up coordinates: the title is placed at rect:maxy() and the
-- category labels below rect:miny().
-- @param canvas drawing target; the chained calls assume setPaint/setFont/
--   drawText return an object supporting the next call (presumably the canvas)
-- @param rect area to draw into
-- @param fonts table with title, axisValue, axisCategory and an optional
--   axisTitle (an axis title is drawn only when present)
-- @param paints table with title, label, fill, axis, an optional background,
--   and optional per-series paints 'data', 'series1', 'series2', 'series3'
local function drawThumbnail(canvas, rect, fonts, paints)
    if paints.background then
        canvas:setPaint(paints.background)
            :fill(Path.rect(canvas:metrics():rect(), 3))
    end
    local PADX, PADY = 4, 2
    -- The widest value label ('100') sizes the left gutter and, as 2/3 of its
    -- width, the right margin.
    local hundred = StyledString.new('100', { font = fonts.axisValue })
    local valueRect = hundred:measure()
    rect = rect:inset{ left = valueRect:width() + PADX, bottom = 0, right = valueRect:width() * 2 / 3, top = 0 }
    canvas:setPaint(paints.title)
        :setFont(fonts.title)
        :drawText('Title', rect:midx(), rect:maxy() - fonts.title:ascent(), 0.5)
    -- Remove the title line, plus half a value label so the top '100' label fits.
    rect = rect:inset{ left = 0, bottom = 0, right = 0, top = fonts.title:ascent() + fonts.title:descent() + PADY + valueRect:height() / 2 }
    if fonts.axisTitle then
        canvas:setFont(fonts.axisTitle)
            :drawText('Axis Title', rect:midx(), rect:miny() + fonts.axisTitle:descent(), 0.5)
        rect = rect:inset{ left = 0, bottom = fonts.axisTitle:ascent() + fonts.axisTitle:descent() + PADY, right = 0, top = 0 }
    end
    -- Reserve a row for the category labels; what remains is the plot area.
    rect = rect:inset{ left = 0, bottom = fonts.axisCategory:ascent() + fonts.axisCategory:descent() + PADY, right = 0, top = 0 }
    canvas:setPaint(paints.fill)
        :fill(Path.rect(rect))
    local valueHeight = valueRect:height()
    local minx, miny = rect:minx(), rect:miny()
    local width, height = rect:width(), rect:height()
    -- NOTE(review): labelY (the category-label baseline) is derived from
    -- valueRect, which was measured with fonts.axisValue, while the row above
    -- was reserved using fonts.axisCategory metrics; confirm this is intended.
    local labelX, labelY = minx - PADX, miny - PADY - valueRect:maxy()
    -- Centres of the three category columns.
    local xs = { minx + width * 1 / 6, minx + width * 3 / 6, minx + width * 5 / 6 }
    -- Value labels are right-aligned (alignment 1) and vertically centred on
    -- 0%, 50% and 100% of the plot height; category labels are centred (0.5).
    canvas:setPaint(paints.label)
        :setFont(fonts.axisValue)
        :drawText('0', labelX, miny - valueHeight / 2 - valueRect:miny(), 1)
        :drawText('50', labelX, miny - valueHeight / 2 - valueRect:miny() + height * 1 / 2, 1)
        :drawText(hundred, labelX, miny - valueHeight / 2 - valueRect:miny() + height, 1)
        :setFont(fonts.axisCategory)
        :drawText('Label', xs[1], labelY, 0.5)
        :drawText('Label', xs[2], labelY, 0.5)
        :drawText('Label', xs[3], labelY, 0.5)
    -- Bar heights as fractions of the plot height. The three 'seriesN' rows
    -- sum to the 'data' row, so either set of paints draws the same totals.
    local data = {
        { paint = 'data', fractions = { 0.4, 0.9, 0.6 } },
        { paint = 'series1', fractions = { 0.15, 0.35, 0.3 } },
        { paint = 'series2', fractions = { 0.2, 0.25, 0.2 } },
        { paint = 'series3', fractions = { 0.05, 0.3, 0.1 } },
    }
    -- Running top of each column's stack; advanced only by series that are drawn.
    local bottomFractions = { 0, 0, 0 }
    for seriesIndex = 1, #data do
        local series = data[seriesIndex]
        -- Series without a paint in `paints` are skipped entirely.
        if paints[series.paint] then
            canvas:setPaint(paints[series.paint])
            for index = 1, 3 do
                canvas:fill(Path.rect{ left = xs[index] - width / 10, bottom = miny + bottomFractions[index] * height, right = xs[index] + width / 10, top = miny + (bottomFractions[index] + series.fractions[index]) * height })
                bottomFractions[index] = bottomFractions[index] + series.fractions[index]
            end
        end
    end
    -- Baseline along the bottom of the plot area.
    canvas:setPaint(paints.axis)
        :stroke(Path.line{ x1 = minx, x2 = minx + width, y1 = miny, y2 = miny })
end

--- Draws the sample chart using `typographyScheme` fonts (size 12) and fixed
-- grey paints, with one 'data' bar series and no background.
function CategoryGraph:drawTypographySchemePreview(canvas, rect, typographyScheme)
    local FONT_SIZE = 12
    local function schemeFont(role)
        return typographyScheme:getFont(role, FONT_SIZE)
    end
    local previewFonts = {
        title = schemeFont(TypographyScheme.titleFont),
        axisTitle = schemeFont(TypographyScheme.subtitleFont),
        axisValue = schemeFont(TypographyScheme.quantityFont),
        axisCategory = schemeFont(TypographyScheme.categoryFont),
        label = schemeFont(TypographyScheme.labelFont),
    }
    local previewPaints = {
        title = Color.gray(0, 1),
        label = Color.gray(0, 1),
        axis = Color.gray(0, 0.4),
        fill = Color.invisible,
        data = Color.gray(0, 0.4),
    }
    drawThumbnail(canvas, rect, previewFonts, previewPaints)
end

--- Draws the sample chart using `colorScheme` paints and the graph's own
-- typography scheme (size 12, no axis title). It uses three data-series paints
-- in place of a single 'data' series.
function CategoryGraph:drawColorSchemePreview(canvas, rect, colorScheme)
    local FONT_SIZE = 12
    local SERIES_COUNT = 3
    local typographyScheme = self:getTypographyScheme()
    local function schemeFont(role)
        return typographyScheme:getFont(role, FONT_SIZE)
    end
    local function schemePaint(role)
        return colorScheme:getPaint(role)
    end
    local previewFonts = {
        title = schemeFont(TypographyScheme.titleFont),
        axisValue = schemeFont(TypographyScheme.quantityFont),
        axisCategory = schemeFont(TypographyScheme.categoryFont),
        label = schemeFont(TypographyScheme.labelFont),
    }
    local previewPaints = {
        background = schemePaint(ColorScheme.pageBackgroundPaint),
        title = schemePaint(ColorScheme.titlePaint),
        label = schemePaint(ColorScheme.labelPaint),
        axis = schemePaint(ColorScheme.strokePaint),
        fill = schemePaint(ColorScheme.backgroundPaint),
    }
    for seriesIndex = 1, SERIES_COUNT do
        previewPaints['series' .. seriesIndex] = colorScheme:getDataSeriesPaint(seriesIndex, SERIES_COUNT)
    end
    drawThumbnail(canvas, rect, previewFonts, previewPaints)
end

-- Description of an axis, keyed by whether that axis carries the values.
local _axisDescriptions = { [true] = 'Values', [false] = 'Categories' }

-- Describes an axis that carries the values exactly when the graph's
-- orientation equals `valueOrientation`.
local function describeAxis(graph, valueOrientation)
    return _axisDescriptions[graph:getOrientation() == valueOrientation]
end

--- 'Values' in horizontal orientation, otherwise 'Categories'.
function CategoryGraph:getHorizontalAxisDescription()
    return describeAxis(self, Graph.horizontalOrientation)
end

--- 'Values' in vertical orientation, otherwise 'Categories'.
function CategoryGraph:getVerticalAxisDescription()
    return describeAxis(self, Graph.verticalOrientation)
end

-- Name of the property holding the graph's dataset.
local DATASET_PROPERTY = 'dataset'

--- Stores `dataset` in the graph's 'dataset' property.
function CategoryGraph:setDataset(dataset)
    self:setProperty(DATASET_PROPERTY, dataset)
end

--- Returns the value of the graph's 'dataset' property.
function CategoryGraph:getDataset()
    return self:getProperty(DATASET_PROPERTY)
end

--- Category graphs support switching orientation (see setOrientation).
-- @treturn boolean always true
function CategoryGraph:isOrientable()
    local orientable = true
    return orientable
end

--- Sets the orientation by writing it to the graph's orientation hook.
function CategoryGraph:setOrientation(orientation)
    local hook = self._orientationHook
    hook:setValue(orientation)
end

--- Returns the orientation currently stored in the graph's orientation hook.
function CategoryGraph:getOrientation()
    local hook = self._orientationHook
    return hook:getValue()
end

-- Binds a CategoryAxis to this graph's 'dataset' property hook; axes of any
-- other kind are left as they are.
local function bindDatasetIfCategoryAxis(graph, axis)
    if axis:isa(CategoryAxis) then
        axis:setDatasetHook(graph:getPropertyHook('dataset'))
    end
end

--- Installs the horizontal axis, first binding it to the dataset if it is a
-- CategoryAxis.
function CategoryGraph:setHorizontalAxis(axis)
    bindDatasetIfCategoryAxis(self, axis)
    super.setHorizontalAxis(self, axis)
end

--- Installs the vertical axis, first binding it to the dataset if it is a
-- CategoryAxis.
function CategoryGraph:setVerticalAxis(axis)
    bindDatasetIfCategoryAxis(self, axis)
    super.setVerticalAxis(self, axis)
end

--- Returns the category axis: horizontal in vertical orientation, vertical
-- otherwise.
function CategoryGraph:getCategoryAxis()
    if self:getOrientation() ~= Graph.verticalOrientation then
        return self:getVerticalAxis()
    end
    return self:getHorizontalAxis()
end

--- Returns the numeric axis: vertical in vertical orientation, horizontal
-- otherwise.
function CategoryGraph:getNumericAxis()
    if self:getOrientation() ~= Graph.verticalOrientation then
        return self:getHorizontalAxis()
    end
    return self:getVerticalAxis()
end

--- True when the 'groupStyle' property is the stacked group style.
function CategoryGraph:isStacked()
    local groupStyle = self:getProperty('groupStyle')
    return groupStyle == CategoryGraph.stackedGroupStyle
end

-- Spacing between the items of a category group, per group style. Built once
-- at load time rather than on every call (it is queried on every draw).
local groupItemSpacingByStyle = {
    [CategoryGraph.stackedGroupStyle] = -1,
    [CategoryGraph.tightGroupStyle] = -1/3,
    [CategoryGraph.wideGroupStyle] = 1/3,
}

--- Returns the item spacing for the current 'groupStyle' property.
-- The spacing is -1 for stacked, -1/3 for tight and 1/3 for wide. Any other
-- value, including nil, falls back to the wide spacing. drawLayers passes the
-- result to each axis's getScaler.
-- @treturn number
function CategoryGraph:getGroupItemSpacing()
    local groupStyle = self:getProperty('groupStyle')
    -- Reading a table with a nil key yields nil, so an unset style falls back.
    return groupItemSpacingByStyle[groupStyle]
        or groupItemSpacingByStyle[CategoryGraph.wideGroupStyle]
end

--- Draws every layer in `layerList` into `contentRect`.
-- Unless the graph is stacked, groupable layers are counted first. Each
-- groupable layer then gets its 1-based rank among them, and both axis
-- scalers receive that rank (false/nil for non-groupable layers) along with
-- the count and the group item spacing. Each layer is drawn under
-- Profiler.time and canvas:pcall. The value canvas:pcall returns is stored per
-- layer class and handed to the next layer of the same class as its
-- intralayer state, starting from {}.
function CategoryGraph:drawLayers(canvas, contentRect, layerList)
    local groupCount = 0
    if not self:isStacked() then
        for layer in layerList:iter() do
            if layer:isGroupable() then
                groupCount = groupCount + 1
            end
        end
    end

    local horizontalAxis = self:getHorizontalAxis()
    local verticalAxis = self:getVerticalAxis()
    local stateByClass = {}
    local groupRank = 0
    local itemSpacing = self:getGroupItemSpacing()

    for layer in layerList:iter() do
        local className = layer:class()

        local function drawOneLayer()
            local groupable = layer:isGroupable()
            if groupable then
                groupRank = groupRank + 1
            end
            local rank = groupable and groupRank
            local xScaler = horizontalAxis:getScaler(contentRect, itemSpacing, groupCount, rank)
            local yScaler = verticalAxis:getScaler(contentRect, itemSpacing, groupCount, rank)
            local sequence = layer:makePropertySequence()
            return layer:draw(canvas, contentRect:copy(), sequence, xScaler, yScaler, stateByClass[className] or {})
        end

        Profiler.time(className .. ":draw", function()
            stateByClass[className] = canvas:pcall(drawOneLayer)
        end)
    end
end

-- Module result: the CategoryGraph prototype (the same table is also assigned
-- to the global CategoryGraph near the top of this file).
return CategoryGraph
